#pragma once

#include <cuda_bf16.h>
#include <cuda_fp16.h>

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstdlib>

#include "ln.h"

////////////////////////////////////////////////////////////////////////////////////////////////////

constexpr uint32_t THREADS_PER_WARP = 32;

////////////////////////////////////////////////////////////////////////////////////////////////////

inline void check_cuda_(cudaError_t status, const char *file, int line) {
  if (status != cudaSuccess) {
    fprintf(stderr, "CUDA Error: %s %s %d\n", cudaGetErrorString(status), file, line);
    exit(status);
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

#define CHECK_CUDA(ans) \
  { check_cuda_((ans), __FILE__, __LINE__); }
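// Usage example: wrap CUDA runtime calls that return cudaError_t, e.g.
//   CHECK_CUDA(cudaMalloc(&ptr, bytes));
//   CHECK_CUDA(cudaDeviceSynchronize());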

////////////////////////////////////////////////////////////////////////////////////////////////////

#define DIVUP(x, y) (((x) + ((y)-1)) / (y))
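// Integer ceiling division, e.g. DIVUP(10, 4) == 3: the number of 4-wide chunks needed to cover 10 elements.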

////////////////////////////////////////////////////////////////////////////////////////////////////

#define REGISTER_FWD_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG) \
  void ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE(LaunchParams<FwdParams> &launch_params,           \
                                                                    const bool configure_params) {                    \
    launch_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG>(        \
        launch_params, configure_params);                                                                             \
  }                                                                                                                   \
  static FwdRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE>                                                        \
      reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE(                                                      \
          ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
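// Hypothetical usage sketch: the type parameters must be single-token type names (e.g. fp16, fp32)
// because they are pasted into the launcher name, so
//   REGISTER_FWD_LAUNCHER(768, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
// would define ln_fwd_768_fp16_fp16_fp16_fp32 and register it through a static FwdRegistrar.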

////////////////////////////////////////////////////////////////////////////////////////////////////

#define REGISTER_BWD_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, \
                              BYTES_PER_LDG_FINALIZE)                                                                 \
  void ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE(LaunchParams<BwdParams> &launch_params,           \
                                                                    const bool configure_params) {                    \
    launch_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG,         \
            BYTES_PER_LDG_FINALIZE>(launch_params, configure_params);                                                 \
  }                                                                                                                   \
  static BwdRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE>                                                        \
      reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE(                                                      \
          ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float2 operator+(const float2 &a, const float2 &b) { return {a.x + b.x, a.y + b.y}; }

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void operator+=(float2 &a, const float2 &b) {
  a.x += b.x;
  a.y += b.y;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
struct Sum {
  inline __device__ Sum() {}
  inline __device__ T operator()(const T &a, const T &b) { return a + b; }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
inline __device__ T warp_shuffle_xor(const T &x, uint32_t idx) {
  return __shfl_xor_sync(uint32_t(-1), x, idx);
}

template <>
inline __device__ float2 warp_shuffle_xor<float2>(const float2 &x, uint32_t idx) {
  return {warp_shuffle_xor(x.x, idx), warp_shuffle_xor(x.y, idx)};
}

template <typename T>
inline __device__ T warp_shuffle_down(const T &x, uint32_t idx) {
  return __shfl_down_sync(uint32_t(-1), x, idx);
}

template <>
inline __device__ float2 warp_shuffle_down<float2>(const float2 &x, uint32_t idx) {
  return {warp_shuffle_down(x.x, idx), warp_shuffle_down(x.y, idx)};
}

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace layer_norm {

////////////////////////////////////////////////////////////////////////////////////////////////////

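// 64-byte and 32-byte aggregates of 128-bit uint4 chunks; together with BytesToType below, they
// let Vec describe extra-wide accesses (typically lowered to multiple 128-bit transactions).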
struct uint16 {
  uint4 u;
  uint4 v;
  uint4 s;
  uint4 t;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct uint8 {
  uint4 u;
  uint4 v;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int BYTES>
struct BytesToType {};

template <>
struct BytesToType<64> {
  using Type = uint16;
  static_assert(sizeof(Type) == 64);
};

template <>
struct BytesToType<32> {
  using Type = uint8;
  static_assert(sizeof(Type) == 32);
};

template <>
struct BytesToType<16> {
  using Type = uint4;
  static_assert(sizeof(Type) == 16);
};

template <>
struct BytesToType<8> {
  using Type = uint64_t;
  static_assert(sizeof(Type) == 8);
};

template <>
struct BytesToType<4> {
  using Type = uint32_t;
  static_assert(sizeof(Type) == 4);
};

template <>
struct BytesToType<2> {
  using Type = uint16_t;
  static_assert(sizeof(Type) == 2);
};

template <>
struct BytesToType<1> {
  using Type = uint8_t;
  static_assert(sizeof(Type) == 1);
};
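
// Example: BytesToType<16>::Type is uint4, so a Vec of 8 halfs (16 bytes) below is moved with a
// single 128-bit load/store.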

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
struct TypeToVec2 {};

template <>
struct TypeToVec2<float> {
  using Type = float2;
};

template <>
struct TypeToVec2<half> {
  using Type = half2;
};

template <>
struct TypeToVec2<nv_bfloat16> {
  using Type = nv_bfloat162;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int INDEX>
struct Get {
  template <typename T, typename R>
  static inline __device__ R of(const T &vec);
};

template <>
template <typename T, typename R>
inline __device__ R Get<0>::of(const T &vec) {
  return vec.x;
}

template <>
template <typename T, typename R>
inline __device__ R Get<1>::of(const T &vec) {
  return vec.y;
}

template <>
template <typename T, typename R>
inline __device__ R Get<2>::of(const T &vec) {
  return vec.z;
}

template <>
template <typename T, typename R>
inline __device__ R Get<3>::of(const T &vec) {
  return vec.w;
}
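
// Example: for float2 v, Get<0>::of<float2, float>(v) yields v.x and Get<1>::of<float2, float>(v)
// yields v.y; used below to unpack the (mean, m2) fields of a stats_t.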

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Src, typename Dst>
struct Converter {
  static inline __device__ Dst convert(const Src &from) { return Dst(from); }
};

template <>
struct Converter<float2, half2> {
  static inline __device__ half2 convert(const float2 &x) { return __float22half2_rn(x); }
};

template <>
struct Converter<float2, nv_bfloat162> {
  static inline __device__ nv_bfloat162 convert(const float2 &x) {
#if __CUDA_ARCH__ >= 800
    return __float22bfloat162_rn(x);
#else
    // Pre-sm80 fallback: round each float separately and pack the pair into one nv_bfloat162.
    union {
      nv_bfloat162 raw;
      nv_bfloat16 elt[2];
    } tmp;
    tmp.elt[0] = __float2bfloat16_rn(x.x);
    tmp.elt[1] = __float2bfloat16_rn(x.y);
    return tmp.raw;
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
struct Zeros {
  static inline __device__ T get() { return T(0.f); }
};

template <>
struct Zeros<float2> {
  static inline __device__ float2 get() { return make_float2(0.f, 0.f); }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Elt_type, uint32_t NUM_ELT>
struct Vec {
  enum { BYTES = NUM_ELT * sizeof(Elt_type) };

  using Vec_type = typename BytesToType<BYTES>::Type;

  using Alias_type = union {
    Vec_type vec;
    Elt_type elt[NUM_ELT];
  };

  Alias_type data;

  template <typename S>
  inline __device__ void to(Vec<S, NUM_ELT> &other) {
#pragma unroll
    for (int it = 0; it < NUM_ELT; it++) {
      other.data.elt[it] = S(this->data.elt[it]);
    }
  }

  template <typename Op>
  inline __device__ void assign(const Op &op) {
#pragma unroll
    for (int it = 0; it < NUM_ELT; it++) {
      this->data.elt[it] = op(it);
    }
  }

  inline __device__ void load_from(const void *base_ptr, const size_t idx) {
    this->data.vec = static_cast<const Vec_type *>(base_ptr)[idx];
  }

  inline __device__ void store_to(void *base_ptr, const size_t idx) {
    static_cast<Vec_type *>(base_ptr)[idx] = this->data.vec;
  }
};
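
// Usage sketch (hypothetical pointer/index names): move 8 halfs with a single 128-bit
// transaction and widen them to float for accumulation.
//
//   Vec<half, 8> x;
//   x.load_from(row_ptr, vec_idx);  // vec_idx counts 16-byte vectors
//   Vec<float, 8> xf;
//   x.to(xf);                       // element-wise half -> float
//   xf.to(x);                       // element-wise float -> half (after updates)
//   x.store_to(row_ptr, vec_idx);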

////////////////////////////////////////////////////////////////////////////////////////////////////

template <uint32_t CTAS_PER_ROW>
struct InterCTASync {
  template <typename Params>
  inline __device__ InterCTASync(Params &params, uint32_t bidm, uint32_t bidn)
      : phase_counter_(0),
        b0_(params.barrier + bidm),                         // First barrier for this CTA group.
        b1_(params.barrier + bidm + params.ctas_per_col) {  // Second barrier for this CTA group.
    // BARRIERS ARE ASSUMED TO BE INITIALIZED TO 0!
  }

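  // Atomically add `step` to the barrier with release semantics, then spin on acquire loads
  // until the barrier reaches the expected count.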
  inline __device__ void spin_wait_(int *barrier, int step, int expected) {
    asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), "r"(step));
    for (int found = -1; found != expected;) {
      asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier));
    }
  }

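  // Four phases of phase_counter_ (mod 4), with barriers starting at 0:
  //   phase 0: b0 counts up   to CTAS_PER_ROW
  //   phase 1: b1 counts up   to CTAS_PER_ROW
  //   phase 2: b0 counts down to 0
  //   phase 3: b1 counts down to 0
  // Alternating two barriers and the count direction lets back-to-back sync() calls proceed
  // without an explicit reset.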
  inline __device__ void sync() {
    // ALL THREADS MUST ENTER!

    // We switch barrier every iteration.
    int *barrier = phase_counter_ & 0x1 ? b1_ : b0_;
    // We decrement every other iteration.
    bool dec = phase_counter_ & 0x2;
    int step = dec ? -1 : 1;
    int expected = dec ? 0 : CTAS_PER_ROW;
    // There are only 4 phases: up/down for b0/b1.
    phase_counter_ = (phase_counter_ + 1) & 0x3;

    if (threadIdx.x == 0) {
      spin_wait_(barrier, step, expected);
    }
    // CTA waits for thread 0
    __syncthreads();
  }

  int phase_counter_;
  int *b0_;
  int *b1_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

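// The reducers form a three-level hierarchy: intra-warp reduction via shuffles
// (Reducer<T, 1, WARPS_M, 1>), intra-CTA reduction across warps via shared memory
// (Reducer<T, 1, WARPS_M, WARPS_N>), and inter-CTA reduction via a global workspace plus
// InterCTASync (the general Reducer below).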
template <typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
struct Reducer : public Reducer<T, 1, WARPS_M, WARPS_N> {
  using InterCTASync = InterCTASync<CTAS_PER_ROW>;
  using Base = Reducer<T, 1, WARPS_M, WARPS_N>;
  using Type = typename Base::Type;

  enum { SMEM_BYTES = Base::SMEM_BYTES };

  enum { WS_BARRIER_BYTES = 2 * sizeof(int) };
  enum { WS_DATA_BYTES = WARPS_M * CTAS_PER_ROW * sizeof(T) };

  // size of the barriers + temporary result per CTA (multiply with CTAS_PER_ROW to get total)
  enum { WORKSPACE_BYTES_PER_GROUP = Base::WORKSPACE_BYTES_PER_GROUP + WS_BARRIER_BYTES + WS_DATA_BYTES };

  template <typename Params>
  inline __device__ Reducer(Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n,
                            uint32_t lane, void *smem)
      : Base(params, bidm, bidn, warp_m, warp_n, lane, smem),
        inter_cta_(params, bidm, bidn),
        bidn_(bidn),  // CTA id within the group.
        w0_(static_cast<T *>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW),
        w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW) {}

  template <typename Op>
  inline __device__ T allreduce(T data, Op &op) {
    data = Base::reduce(data, op);
    // We switch workspace every iteration.
    T *workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;

    // Lane 0 of the first warp (warp_n_ == 0) in each row holds the CTA-local result.
    if (this->warp_n_ == 0 && this->lane_ == 0) {
      workspace[bidn_] = data;
    }
    inter_cta_.sync();
    static_assert(CTAS_PER_ROW <= 32);
    T total = Zeros<T>::get();
    if (this->lane_ < CTAS_PER_ROW) {
      total = workspace[this->lane_];
    }
    total = Reducer<T, 1, 1, 1>::allreduce_(total, op);

    return total;
  }

  InterCTASync inter_cta_;

  T *w0_;
  T *w1_;
  int bidn_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T, uint32_t WARPS_M>
struct Reducer<T, 1, WARPS_M, 1> {
  using Type = T;
  enum { SMEM_BYTES = 0 };
  enum { WORKSPACE_BYTES_PER_GROUP = 0 };

  enum { THREADS_PER_WARP = 32 };

  template <typename Params>
  inline __device__ Reducer(Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n,
                            uint32_t lane, void *smem)
      : warp_n_(warp_n), lane_(lane) {}

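  // Butterfly allreduce: after log2(THREADS_PER_WARP) = 5 xor-shuffle rounds, every lane holds
  // the reduction over all 32 lanes of the warp.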
  template <typename Op>
  static inline __device__ T allreduce_(T data, Op &op) {
#pragma unroll
    for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
      data = op(data, warp_shuffle_xor(data, it));
    }
    return data;
  }

  template <typename Op>
  inline __device__ T allreduce(T data, Op &op) {
    return allreduce_(data, op);
  }

  template <typename Op>
  inline __device__ T reduce(T data, Op &op) {
    // Only lane 0 holds the result!
#pragma unroll
    for (int it = THREADS_PER_WARP / 2; it > 0; it /= 2) {
      data = op(data, warp_shuffle_down(data, it));
    }
    return data;
  }
  int warp_n_;
  int lane_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T, uint32_t WARPS_M, uint32_t WARPS_N>
struct Reducer<T, 1, WARPS_M, WARPS_N> : public Reducer<T, 1, WARPS_M, 1> {
  using Base = Reducer<T, 1, WARPS_M, 1>;

  using Type = T;

  enum { SMEM_BYTES = Base::SMEM_BYTES + WARPS_M * WARPS_N * sizeof(T) * 2 };
  enum { WORKSPACE_BYTES_PER_GROUP = 0 };

  enum { THREADS_PER_WARP = 32 };

  template <typename Params>
  inline __device__ Reducer(Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n,
                            uint32_t lane, void *smem)
      : Base(params, bidm, bidn, warp_m, warp_n, lane, smem), use0_(true) {
    smem0_ = static_cast<T *>(smem) + warp_m * WARPS_N;
    smem1_ = smem0_ + WARPS_M * WARPS_N;
  }

  template <typename Op>
  inline __device__ T allreduce(T data, Op &op) {
    T *smem = use0_ ? smem0_ : smem1_;
    use0_ = !use0_;
    data = Base::reduce(data, op);
    if (this->lane_ == 0) {
      smem[this->warp_n_] = data;
    }
    __syncthreads();
    T out = Zeros<T>::get();
#pragma unroll
    for (int it = 0; it < WARPS_N; it++) {
      out = op(out, smem[it]);
    }
    return out;
  }

  template <typename Op>
  inline __device__ T reduce(T data, Op &op) {
    T *smem = use0_ ? smem0_ : smem1_;
    use0_ = !use0_;
    // Only the intra-CTA group leader (warp_n_ == 0, lane_ == 0) holds the result!
    data = Base::reduce(data, op);
    if (this->lane_ == 0) {
      smem[this->warp_n_] = data;
    }
    __syncthreads();
    T out = Zeros<T>::get();
    if (this->warp_n_ == 0 && this->lane_ == 0) {
#pragma unroll
      for (int it = 0; it < WARPS_N; it++) {
        out = op(out, smem[it]);
      }
    }
    return out;
  }

  T *smem0_;
  T *smem1_;
  bool use0_;
};
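// Kernel-side usage sketch (hypothetical variable names, WARPS_N > 1): every thread contributes
// a partial value and receives the reduction over its row's WARPS_N warps.
//
//   __shared__ char smem[Reducer<float, 1, WARPS_M, WARPS_N>::SMEM_BYTES];
//   Reducer<float, 1, WARPS_M, WARPS_N> reducer(params, bidm, bidn, warp_m, warp_n, lane, smem);
//   Sum<float> sum;
//   float total = reducer.allreduce(partial, sum);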

////////////////////////////////////////////////////////////////////////////////////////////////////

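// Merges per-lane partial statistics using the parallel update of Chan et al., where m is the
// mean and m2 the sum of squared deviations over n elements:
//   n_ab  = n_a + n_b
//   m_ab  = (n_a * m_a + n_b * m_b) / n_ab
//   m2_ab = m2_a + m2_b + (m_a - m_b)^2 * n_a * n_b / n_ab
// Worked example: merging {n=2, m=1, m2=0} (data {1,1}) with {n=2, m=3, m2=0} (data {3,3})
// gives n=4, m=2, m2 = 0 + 0 + 2^2 * 2 * 2 / 4 = 4, the sum of squared deviations of {1,1,3,3}.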
template <typename T>
inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, T &n_a, int num_active) {
  // Assume at least the leftmost lane holds valid stats, and initialize
  // step = next_pow2(num_active) / 2 (starting higher might produce NaN).
  int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1);

#pragma unroll
  for (int step = (1 << (highest_bit_set - 1)); step > 0; step /= 2) {
    // Exchange
    T n_b = warp_shuffle_down(n_a, step);
    T m_b = warp_shuffle_down(m_a, step);
    T m2_b = warp_shuffle_down(m2_a, step);

    // Update
    const T n_ab = n_a + n_b;    // We can handle one of them being 0, not both.
    const T rn_ab = 1.f / n_ab;  // Might have different n per thread, otherwise this would simplify :(
    const T delta = m_a - m_b;
    const T m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab;
    const T m_ab = (n_a * m_a + n_b * m_b) * rn_ab;

    n_a = n_ab;
    m_a = m_ab;
    m2_a = m2_ab;
  }
  // Intra-warp broadcast (only lane 0 has valid stats).
  m_a = __shfl_sync(uint32_t(-1), m_a, 0);
  m2_a = __shfl_sync(uint32_t(-1), m2_a, 0);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

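// The Stats types mirror the Reducer hierarchy: warp-level statistics (via the warp Reducer),
// CTA-level (merging per-warp results through shared memory), and inter-CTA (merging per-CTA
// results through the workspace). The upper two levels combine (mean, m2) pairs with
// warp_chan_upd_dynamic.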
template <typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
struct Stats {
  // This could be done generically with the Reducer. But then we would have to exchange 3 instead of 2 fields.

  using InterCTASync = InterCTASync<CTAS_PER_ROW>;
  using BlockStats = Stats<T, 1, WARPS_M, WARPS_N>;
  using stats_t = typename BlockStats::stats_t;

  enum { SMEM_BYTES = BlockStats::SMEM_BYTES };

  template <typename Params>
  inline __device__ Stats(Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane,
                          void *smem)
      : inter_cta_(params, bidm, bidn),
        block_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem),
        bidn_(bidn),  // CTA id within the group.
        w0_(static_cast<stats_t *>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW),
        w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW),
        warp_n_(warp_n),
        lane_(lane) {}

  template <uint32_t N>
  inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
    constexpr T ELTS_PER_ROW_PER_CTA = N * WARPS_N * THREADS_PER_WARP;
    // TODO: rn is not actually needed here.
    constexpr T block_rn = 1.f / T(ELTS_PER_ROW_PER_CTA);
    stats_t block_stats = block_stats_.compute(elts, block_rn);

    stats_t *workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;

    if (warp_n_ == 0 && lane_ == 0) {
      workspace[bidn_] = block_stats;
    }

    // Wait for all CTAS_PER_ROW CTAs in the group to have written their result.
    inter_cta_.sync();

    T n = Zeros<T>::get();
    T m = Zeros<T>::get();
    T m2 = Zeros<T>::get();

    // Assume the CTA group size in N is at most 32, so that we can finalize with a single warp.
    static_assert(CTAS_PER_ROW <= 32);

    // Every warp does the final reduction locally.
    if (lane_ < CTAS_PER_ROW) {
      stats_t result = workspace[lane_];
      n = ELTS_PER_ROW_PER_CTA;
      m = layer_norm::Get<0>::of<stats_t, T>(result);
      m2 = layer_norm::Get<1>::of<stats_t, T>(result);
    }

    warp_chan_upd_dynamic(m, m2, n, CTAS_PER_ROW);

    return {m, m2};
  }

  InterCTASync inter_cta_;
  BlockStats block_stats_;

  stats_t *w0_;
  stats_t *w1_;
  int bidn_;
  int warp_n_;
  int lane_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T, uint32_t WARPS_M, uint32_t WARPS_N>
struct Stats<T, 1, WARPS_M, WARPS_N> {
  using WarpStats = Stats<T, 1, WARPS_M, 1>;
  using stats_t = typename WarpStats::stats_t;

  enum { SMEM_BYTES = WARPS_M * WARPS_N * sizeof(stats_t) * 2 };

  template <typename Params>
  inline __device__ Stats(Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane,
                          void *smem)
      : warp_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem), use0_(true) {
    smem0_ = static_cast<stats_t *>(smem) + warp_m * WARPS_N;
    smem1_ = smem0_ + WARPS_M * WARPS_N;
  }

  template <uint32_t N>
  inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
    stats_t *smem = use0_ ? smem0_ : smem1_;
    use0_ = !use0_;
    // Compute the warp-local stats for all WARPS_N warps.
    constexpr T warp_rn = 1.f / T(N * THREADS_PER_WARP);
    stats_t warp_stats = warp_stats_.compute(elts, warp_rn);

    // Each warp's leader stores its stats.
    const auto warp_n = warp_stats_.reducer_.warp_n_;
    const auto lane = warp_stats_.reducer_.lane_;
    if (lane == 0) {
      smem[warp_n] = warp_stats;
    }
    __syncthreads();

    T n = Zeros<T>::get();
    T m = Zeros<T>::get();
    T m2 = Zeros<T>::get();

    // Assume there are at most 32 warps, so that we can finalize with a single warp.
    static_assert(WARPS_N <= 32);
    if (lane < WARPS_N) {
      stats_t result = smem[lane];
      n = N * THREADS_PER_WARP;
      m = layer_norm::Get<0>::of<stats_t, T>(result);
      m2 = layer_norm::Get<1>::of<stats_t, T>(result);
    }

    warp_chan_upd_dynamic(m, m2, n, WARPS_N);

    return {m, m2};
  }
  WarpStats warp_stats_;
  stats_t *smem0_;
  stats_t *smem1_;
  bool use0_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T, uint32_t WARPS_M>
struct Stats<T, 1, WARPS_M, 1> {
  using stats_t = typename TypeToVec2<T>::Type;
  // The simple Warp reducer.
  using Reducer = Reducer<T, 1, WARPS_M, 1>;

  enum { SMEM_BYTES = 0 };

  template <typename Params>
  inline __device__ Stats(Params &params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane,
                          void *smem)
      : reducer_(params, bidm, bidn, warp_m, warp_n, lane, smem) {}

  template <uint32_t N>
  inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
    auto sum = Sum<T>();

    T m = Zeros<T>::get();
#pragma unroll
    for (int it = 0; it < N; it++) {
      m += elts[it];
    }
    m = reducer_.allreduce(m, sum) * rn;

    T m2 = Zeros<T>::get();
#pragma unroll
    for (int it = 0; it < N; it++) {
      T diff = (elts[it] - m);
      m2 += diff * diff;
    }
    m2 = reducer_.allreduce(m2, sum);

    return {m, m2};
  }

  Reducer reducer_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace layer_norm
